%%%%%%%%%%%%%% Matrix completion with denoising example %%%%%%%%%%%%%%%%%%%
% Generates a noisy symmetric low-rank matrix, hides part of it, and
% compares the IBCG, TTSA and SBFW bilevel solvers on recovering it.
clc
clear
seed = 123;
rng(seed, 'twister');   % fixed seed: identical data on every run

% Problem size: an n-by-n matrix of rank r.
n = 250;
r = 10;

% Ground truth X_hat = W*W' (n-by-n, rank r) and its column-vector form.
W       = randn(n, r);
X_hat   = W*W';
X_hat_r = X_hat(:);

% Symmetric Gaussian noise E = n_hat*(L + L') with noise factor n_hat.
L     = randn(n, n);
n_hat = 0.5;
E     = n_hat*(L + L');

% Noisy observation matrix (n-by-n).
M = X_hat + E;

% Random symmetric observation mask: each upper-triangular entry (diagonal
% included) is kept with probability 0.8 and mirrored into the lower
% triangle, so roughly 80% of all entries are observed.
O_indx2 = (rand(n) < 0.8);
O_indx2 = triu(O_indx2) + triu(O_indx2, 1)';   % logical + logical -> numeric 0/1
O2 = M .* O_indx2;   % observed matrix of the lower level

% Loss normalisers: f and g divide their data-fit terms by n1 and n2;
% a value of 1 leaves them unscaled.
n2 = 1;
n1 = n2;
% The upper level observes exactly the same entries as the lower level.
O_indx1 = O_indx2;
% Number of observed entries; used only to scale the plotted curves.
dim = nnz(O_indx2);

% Column-vector forms consumed by the objective/gradient handles.
% Built-in reshape is used instead of vec(), which is not core MATLAB
% (it comes from toolboxes such as CVX); the result is identical.
O2_r      = reshape(O2, [], 1);
O_indx2_r = reshape(O_indx2, [], 1);
O_indx1_r = reshape(O_indx1, [], 1);
O1 = M .* O_indx1;   % observed matrix of the upper level


% Solver budgets: iteration cap and time cap (units are defined by the
% solvers; presumably seconds - confirm).
maxiter = 1e4;
maxtime = 200;

% Weights of the pseudo-Huber term (l1) and of the coupling term
% ||X - Y||^2 (l2) in the lower-level objective, plus the pseudo-Huber
% parameter delta.
l1 = 5e-2;
l2 = 5e-2;
delta = 0.9;

% Nuclear norm of the ground truth (1-norm of its singular values). It is
% handed to the solvers as param.a - presumably the radius of a
% nuclear-norm constraint; confirm in the solver code.
a = norm(svd(X_hat), 1);

% All-zero n-by-n starting points, stored sparse.
X_init = sparse(n, n);
Y_init = sparse(n, n);
%% function definition
% The handles expect vectorised (n^2-by-1) iterates: the masks are
% n^2-by-1 and the Hessian blocks are n^2-by-n^2.
% NOTE(review): X_init/Y_init are n-by-n, so the solvers presumably
% vectorise them - confirm.
% Squared norms use the built-in sum(v.^2) instead of CVX's sum_square,
% so no toolbox is needed to evaluate them.

% Upper level: squared error between X and Y on the observed entries.
fun_f = @(X,Y) (1/n1) * sum((X.*O_indx1_r - Y.*O_indx1_r).^2);

% Lower level, using the pseudo-Huber loss to induce sparsity:
%   L_delta(a) = delta^2 * (sqrt(1 + (a/delta)^2) - 1)
%   g(X,Y) = (1/n2)*||Y.*O2mask - O2||^2 + l1*sum(L_delta(Y)) + l2*||X - Y||^2
fun_g = @(X,Y) ((1/n2) * sum((Y.*O_indx2_r - O2_r).^2)) ...
    + l1 * sum((delta^2) * (sqrt(1 + ((Y/delta).^2)) - 1)) ...
    + l2 * sum((X - Y).^2);

% Gradients of f with respect to Y and X.
grad_f_y = @(X,Y) (-2/n1) * (X.*O_indx1_r - Y.*O_indx1_r);
grad_f_x = @(X,Y) (2/n1) * (X.*O_indx1_r - Y.*O_indx1_r);

% Gradient of g in Y; the pseudo-Huber derivative is delta*Y/sqrt(Y^2+delta^2).
grad_g_y = @(X,Y) (2/n2) * (Y.*O_indx2_r - O2_r) + l1*(delta*Y./sqrt(Y.^2+delta^2)) + l2* -2* (X-Y);

% Second derivatives of g. The data-fit and coupling parts of the Y-Hessian
% do not depend on Y, so they are built once here instead of on every call.
H_fit_yy    = (2/n2) * spdiags(O_indx2_r, 0, n^2, n^2);
H_couple_yy = l2*2*speye(n^2);
grad_g_yy = @(Y) H_fit_yy + l1 * spdiags(delta^3 ./ (Y.^2+delta^2).^(1.5), 0, n^2, n^2) + H_couple_yy;
grad_g_yx = -l2*2*speye(n^2);

%% IBCG algorithm
% Curvature bounds for g in Y, read off the Hessian grad_g_yy:
%   data-fit term  (2/n2)*diag(O_indx2_r): eigenvalues 0 or 2/n2
%   pseudo-Huber   l1*delta^3./(Y.^2+delta^2).^1.5: values in (0, l1], max at Y = 0
%   coupling term  2*l2*I
% so the largest eigenvalue is at most 2/n2 + l1 + 2*l2. (A bound using
% l1*delta^1.5 falls below l1 whenever delta < 1 and is not an upper bound.)
mu_g = (2/n2) + (2*l2);
% NOTE(review): the smallest Hessian eigenvalue is only guaranteed to be
% >= 2*l2 (unobserved entries get no data-fit curvature, and the
% pseudo-Huber curvature decays for large |Y|); mu_g above assumes every
% entry is observed. Left unchanged because it sets the step sizes below -
% confirm the intended value.
L_g = (2/n2) + (l1 + 2*l2);
eta = 0.9*(1-(L_g-mu_g)/(L_g+mu_g))/mu_g;

gamma = 0.25/(sqrt(maxiter));
alpha = 2/(mu_g + L_g);

% Vectorised observation matrix (built-in reshape; vec() is not core MATLAB).
M_r = reshape(M, [], 1);
param.ind = O_indx2_r;
param.M = M_r;
param.a = a;
param.eta = eta;       % stepsize for w
param.gam = gamma;     % stepsize for x (FW)
param.alpha = alpha;   % stepsize for y
param.maxiter = maxiter;
param.maxtime = maxtime;

[f_vec1,g_vec1,time_vec1, e1, matrix1] = IBCG(fun_f, grad_f_y,grad_f_x, grad_g_y, grad_g_yx,...
    grad_g_yy,param, X_init, Y_init, X_hat_r);
disp('IBCG Solution Achieved!');
%% TTSA
% Constants supplied to TTSA through the shared param struct. param.a and
% param.maxiter already hold a and maxiter from the IBCG setup, and the
% IBCG stepsize fields stay in the struct.
L_g_yy = L_g;      % curvature bound of g in y, shared with IBCG
L_g_xy = 2*l2;     % norm of grad_g_yx = -2*l2*I
C_g_xy = 2*l2;
L_f_y  = 2/n1;
L_f_x  = 2/n1;
C_f_y  = 2/n1;

param.lg   = L_g;
param.mug  = mu_g;
param.cg   = C_g_xy;
param.lfy  = L_f_y;
param.lfx  = L_f_x;
param.cf   = C_f_y;
param.lgyy = L_g_yy;
param.lgxy = L_g_xy;

[f_vec2, g_vec2, time_vec2, e2, matrix2] = TTSA(fun_f, grad_f_y, grad_f_x, ...
    grad_g_y, grad_g_yx, grad_g_yy, param, X_init, Y_init, X_hat_r);
disp('TTSA Solution Achieved!');

%% SBFW algorithm
% SBFW receives the same handles, starting points and param struct as TTSA.
[f_vec3, g_vec3, time_vec3, e3, matrix3] = SBFW(fun_f, grad_f_y, grad_f_x, ...
    grad_g_y, grad_g_yx, grad_g_yy, param, X_init, Y_init, X_hat_r);
disp('SBFW Solution Achieved!');

%% Figures

% Figure 1: upper-level objective f(x_k, y_k), divided by the number of
% observed entries, against wall-clock time for each solver.
figure;
set(0,'defaulttextinterpreter','latex')   % root-level default: applies to all later figures
set(gcf,'DefaultLineLinewidth',2)
set(gcf,'DefaultLineMarkerSize',5);

% Line handles are not stored: a variable named TTSA or SBFW would shadow
% the solver function of the same name in the workspace.
semilogy(time_vec1, f_vec1/dim, 'DisplayName', 'IBCG', 'Color', 'blue');
hold on;
semilogy(time_vec2, f_vec2/dim, 'DisplayName', 'TTSA', 'Color', 'green');
semilogy(time_vec3, f_vec3/dim, 'DisplayName', 'SBFW', 'Color', 'red');

ylabel('$f(x_k,y_k) $')
xlabel('time(s)')

% 0 is omitted from the ticks: it has no position on a log-scale axis.
yticks([1e-3, 1e-2, 1e-1, 1e0]);
set(gca,'FontSize',24);
legend('Interpreter','latex','Location','southeast')
grid on;
grid minor;
pbaspect([1 0.8 1])

%--------------------------------------------------------------------------
% Figure 2: norm of the lower-level gradient in y, divided by the number of
% observed entries, against wall-clock time for each solver.
figure;
set(0,'defaulttextinterpreter','latex')   % root-level default: applies to all later figures
set(gcf,'DefaultLineLinewidth',2)
set(gcf,'DefaultLineMarkerSize',5);

% Line handles are not stored: a variable named TTSA or SBFW would shadow
% the solver function of the same name in the workspace.
semilogy(time_vec1, g_vec1/dim, 'DisplayName', 'IBCG', 'Color', 'blue');
hold on;
semilogy(time_vec2, g_vec2/dim, 'DisplayName', 'TTSA', 'Color', 'green');
semilogy(time_vec3, g_vec3/dim, 'DisplayName', 'SBFW', 'Color', 'red');

ylabel('$\|\nabla g_y(x_k,y_k)\|$')
xlabel('time(s)')

set(gca,'FontSize',24);
legend('Interpreter','latex','Location','northeast')
grid on;
grid minor
pbaspect([1 0.8 1])

%--------------------------------------------------------------------------
% Figure 3: normalized error e_k reported by each solver against
% wall-clock time.
figure;
set(0,'defaulttextinterpreter','latex')   % root-level default: applies to all later figures
set(gcf,'DefaultLineLinewidth',2)
set(gcf,'DefaultLineMarkerSize',5);

% Line handles are not stored: a variable named TTSA or SBFW would shadow
% the solver function of the same name in the workspace.
semilogy(time_vec1, e1, 'DisplayName', 'IBCG', 'Color', 'blue');
hold on;
semilogy(time_vec2, e2, 'DisplayName', 'TTSA', 'Color', 'green');
semilogy(time_vec3, e3, 'DisplayName', 'SBFW', 'Color', 'red');

ylabel('$\bar{e}$')
xlabel('time(s)')

set(gca,'FontSize',24);
legend('Interpreter','latex','Location','northeast')
grid on;
grid minor;
% 0 is omitted from the ticks: it has no position on a log-scale axis.
yticks([1e-1, 1e0, 1e1]);
